Fused Cross Entropy Loss #1601
Conversation
mgmalek's implementation might have numerical issues when using half-precision dtypes.
Haha, yeah, that was probably me reporting the loss-curve differences on Twitter. Do you have any insights on the best way to fix the kernels? Standard CEL uses bfloat16 as well, IIRC.
The problem is the lm_head weight (proj_weight), not CEL. By default, bf16 matmul uses an fp32 accumulator (along the GEMM K axis). Since the matmul is performed tile by tile along the M and N axes (an MxK @ KxN matmul), registers serve as the accumulator, with no need for an fp32 global-memory buffer. Here, however, we chunk along the GEMM K axis (a DxT @ TxV matmul), so the partial products from each chunk have to be accumulated in a global-memory buffer, and that cross-chunk accumulation happens in the buffer's dtype rather than in fp32 registers. A simple fix is to use an fp32 grad_proj_weight, but that incurs a significant memory overhead. Can you check the loss curve with an fp32 grad_proj_weight?
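To make the accumulation point concrete, here is a minimal sketch (not this PR's kernel; the shapes and chunk count are made up) of how chunking the DxT @ TxV matmul along the token axis moves the accumulation out of the GEMM and into whatever dtype the output buffer uses:

```python
import torch

torch.manual_seed(0)
D, T, V, n_chunks = 256, 2048, 4096, 8

hidden = torch.randn(T, D, dtype=torch.bfloat16)       # activations, T x D
grad_logits = torch.randn(T, V, dtype=torch.bfloat16)  # dL/dlogits, T x V

# Unchunked D x V reference, computed in fp32 for comparison.
ref = hidden.float().t() @ grad_logits.float()

# Chunked along T: each partial D x V product lands in a global buffer,
# so the cross-chunk sum happens in that buffer's dtype. (Illustrative on
# CPU; on GPU the unchunked bf16 GEMM would accumulate in fp32 registers.)
acc_bf16 = torch.zeros(D, V, dtype=torch.bfloat16)     # rounds after every chunk
acc_fp32 = torch.zeros(D, V, dtype=torch.float32)      # the fp32 grad_proj_weight fix
for h, g in zip(hidden.chunk(n_chunks), grad_logits.chunk(n_chunks)):
    partial = h.t() @ g                                # bf16 partial product
    acc_bf16 += partial
    acc_fp32 += partial.float()

print("bf16 accumulator error:", (acc_bf16.float() - ref).abs().max().item())
print("fp32 accumulator error:", (acc_fp32 - ref).abs().max().item())
```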
Gave your recommendation a try, but changing grad_proj_weight to fp32 didn't seem to help much.
Uh oh, now we need more investigation…
It seems that the fused version has a different loss and grad_norm from the very first step? That looks strange…
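One way to chase a step-one mismatch is a standalone parity check against the unfused path on identical inputs; if loss and grad_norm already disagree there, the issue is in the fused forward/backward itself rather than anything accumulated over training. A hedged sketch (`fused_cross_entropy` is a hypothetical stand-in for whichever entry point this PR exposes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, D, V = 512, 256, 1024
hidden = torch.randn(T, D, requires_grad=True)
proj_weight = torch.randn(V, D, requires_grad=True)
labels = torch.randint(0, V, (T,))

# Reference path: materialize logits, then standard cross entropy.
ref_loss = F.cross_entropy(hidden @ proj_weight.t(), labels)
ref_loss.backward()
ref_grad_norm = proj_weight.grad.norm().item()

# Fused path on cloned copies of the same inputs (hypothetical entry point):
# hidden2 = hidden.detach().clone().requires_grad_()
# w2 = proj_weight.detach().clone().requires_grad_()
# fused_loss = fused_cross_entropy(hidden2, w2, labels)
# fused_loss.backward()
# Compare fused_loss vs ref_loss and w2.grad.norm() vs ref_grad_norm;
# a mismatch here reproduces the step-one discrepancy in isolation.
print(ref_loss.item(), ref_grad_norm)
```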
Closing now that Liger is available.
pytorch/pytorch#124480 served as the impetus for getting this integrated.
Adapted the code from @mgmalek's https://github.com/mgmalek/efficient_cross_entropy/blob/main/modules.py so that it properly handles causal/next-token prediction.
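As a rough sketch of what that adaptation looks like (assuming the usual convention that position t predicts token t+1; this illustrates only the label shift and the chunked loop, not the fused backward that provides the actual memory savings):

```python
import torch
import torch.nn.functional as F

def causal_chunked_ce(hidden, proj_weight, input_ids, n_chunks=4):
    """hidden: (B, T, D), proj_weight: (V, D), input_ids: (B, T)."""
    D = hidden.size(-1)
    # Next-token prediction: position t's hidden state is scored against
    # token t+1, so drop the last position and the first label.
    shift_hidden = hidden[:, :-1, :].reshape(-1, D)
    shift_labels = input_ids[:, 1:].reshape(-1)
    total = hidden.new_zeros((), dtype=torch.float32)
    for h, y in zip(shift_hidden.chunk(n_chunks), shift_labels.chunk(n_chunks)):
        logits = h @ proj_weight.t()  # only one (chunk, V) logits tile at a time
        total = total + F.cross_entropy(logits.float(), y, reduction="sum")
    return total / shift_labels.numel()
```

A fused implementation additionally avoids keeping each chunk's logits alive for the backward pass, which is where the real memory savings over plain autograd come from.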